import pandas as pd
import numpy as np
import torch
import tqdm

from main import tokenizer, BATCH_SIZE, model, DEVICE, get_location_predictions

# Directory that create_test_df() reads features.csv, patient_notes.csv and
# test.csv from (relative to the current working directory).
ROOT = "./data/nbme-score-clinical-patient-notes"

def create_test_df():
    """Assemble the inference table from the three CSV files under ``ROOT``.

    ``test.csv`` is left-joined first with ``patient_notes.csv`` and then with
    ``features.csv``; no join keys are given, so pandas joins on the column
    names the frames have in common. Afterwards every ``feature_text`` value
    is rewritten so that ``"-OR-"`` turns into ``"; "`` and each remaining
    hyphen turns into a space.

    Returns:
        The merged ``DataFrame``. Its shape is printed just before returning.
    """

    def _clean(raw):
        # "-OR-" -> ";-" first, so the hyphen pass below yields "; ".
        return raw.replace("-OR-", ";-").replace("-", " ")

    features_df = pd.read_csv(f"{ROOT}/features.csv")
    notes_df = pd.read_csv(f"{ROOT}/patient_notes.csv")
    test_df = pd.read_csv(f"{ROOT}/test.csv")

    frame = test_df.merge(notes_df, how="left").merge(features_df, how="left")
    frame["feature_text"] = list(map(_clean, frame["feature_text"]))

    print(frame.shape)
    return frame


class NBMETestData(torch.utils.data.Dataset):
    """Map-style dataset that tokenizes one (feature_text, pn_history) pair per item.

    Args:
        data: ``DataFrame`` providing ``feature_text`` and ``pn_history``
            columns. Rows are addressed by position, so the frame's index
            labels do not need to be ``0..len-1``.
        tokenizer: Callable invoked as ``tokenizer(feature_text, pn_history,
            truncation=..., max_length=..., padding=...,
            return_offsets_mapping=True)``. Its result must support item
            access for ``input_ids``, ``attention_mask`` and
            ``offset_mapping`` and expose ``sequence_ids()`` (presumably a
            Hugging Face fast tokenizer -- confirm against ``main``).
        max_length: Length every encoded pair is padded/truncated to.
            Defaults to 416, the value that was previously hard-coded.
    """

    def __init__(self, data, tokenizer, max_length=416):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of rows in the underlying frame."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize the row at position ``idx``.

        Returns:
            A tuple ``(input_ids, attention_mask, offset_mapping,
            sequence_ids)`` of NumPy arrays; ``sequence_ids`` is cast to
            ``float16``.
        """
        # Samplers hand out positions in range(len(self)), so index by
        # position; label-based .loc breaks on any non-default index.
        example = self.data.iloc[idx]
        tokenized = self.tokenizer(
            example["feature_text"],
            example["pn_history"],
            # Only the second segment (pn_history) may be shortened.
            truncation="only_second",
            max_length=self.max_length,
            padding="max_length",
            return_offsets_mapping=True,
        )
        tokenized["sequence_ids"] = tokenized.sequence_ids()

        input_ids = np.array(tokenized["input_ids"])
        attention_mask = np.array(tokenized["attention_mask"])
        offset_mapping = np.array(tokenized["offset_mapping"])
        # Float cast: presumably so that None entries (special/padding tokens)
        # become NaN and the array can be batched -- confirm with the tokenizer.
        sequence_ids = np.array(tokenized["sequence_ids"]).astype("float16")

        return input_ids, attention_mask, offset_mapping, sequence_ids

# Build the test table and a loader that walks it in its original order, so
# the concatenated predictions line up row-for-row with `test`.
test = create_test_df()
test_ds = NBMETestData(test, tokenizer)
test_dl = torch.utils.data.DataLoader(
    test_ds,
    batch_size=BATCH_SIZE * 2,
    pin_memory=True,
    shuffle=False,
    drop_last=False,
)

# Run the model over every batch without tracking gradients, collecting the
# raw outputs together with the offsets / sequence ids passed on afterwards.
model.eval()
preds = []
offsets = []
seq_ids = []
with torch.no_grad():
    # `tqdm` is imported as a module at the top of the file, so the progress
    # bar is `tqdm.tqdm`; calling the module itself raises TypeError.
    for batch in tqdm.tqdm(test_dl):
        input_ids = batch[0].to(DEVICE)
        attention_mask = batch[1].to(DEVICE)
        # Offsets and sequence ids are only used after the loop; keep on CPU.
        offset_mapping = batch[2]
        sequence_ids = batch[3]
        logits = model(input_ids, attention_mask)
        preds.append(logits.cpu().numpy())
        offsets.append(offset_mapping.numpy())
        seq_ids.append(sequence_ids.numpy())

# Stitch the per-batch arrays back together along the example axis.
preds = np.concatenate(preds, axis=0)
offsets = np.concatenate(offsets, axis=0)
seq_ids = np.concatenate(seq_ids, axis=0)

# Convert model outputs to location values and write the submission file.
location_preds = get_location_predictions(preds, offsets, seq_ids, test=True)
test["location"] = location_preds
test[["id", "location"]].to_csv("submission.csv", index=False)
# Notebook-style sanity check: re-reads the file just written; the resulting
# preview is only displayed in an interactive session.
pd.read_csv("submission.csv").head()